{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OnbjvoqHM_h5"
      },
      "source": [
        "# Fine-tune the XLM-T multilingual Twitter language model for elite criticism detection\n",
        "\n",
        "*author:* Hauke Licht\\\n",
        "\n",
        "In this notebook, I fine-tune the pre-trained XLM-T multilingual Twitter language model provided by [Francesco Barbieri, Luis Espinosa Anke, Jose Camacho-Collados](https://arxiv.org/abs/2104.12250) for detecting elite criticism (anti-elite messages) in political parties' tweets.\n",
        "\n",
        "XLM-T is a *crosslingual* language model (hence **X**LM) with an XLM-R architecture ([Conneau et al. 2019](https://arxiv.org/abs/1911.02116)) that has been pre-trained on a large, multilingual corpus of tweets.\n",
        "I add a classification head on top of this model and fine-tune it to classify tweets posted by political parties according to whether or not they contain elite-critical statements.\n",
        "\n",
        "The labeled dataset I use contains more than 5,300 tweets sampled from the tweets posted by political parties in 20 Western countries between 2008 and early 2021.\n",
        "Each tweet was annotated by six crowd coders, and I aggregated their annotations into tweet-level labels using a Dawid and Skene ([1979](https://doi.org/10.2307/2346806)) annotation model (cf. [Paun et al. 2018](https://aclanthology.org/Q18-1040.pdf)).\n",
        "\n",
        "**_Source:_** This notebook is adapted from the [fine-tuning notebook](https://colab.research.google.com/drive/1IAA1h8u53O1hi9807u7oOFuT3728N0-n) provided by the creators of XLM-T.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tmrIPXRXg9lC"
      },
      "source": [
        "# Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rTvbV_CepVpy"
      },
      "outputs": [],
      "source": [
        "# the model name in the huggingface model hub\n",
        "MODEL = 'cardiffnlp/twitter-xlm-roberta-base'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AYOhzvayKMq7"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "base_path = os.path.join('..', '..')\n",
        "data_path = os.path.join(base_path, 'data')\n",
        "input_path = os.path.join(data_path, 'intermediate', 'training')\n",
        "res_dir = os.path.join(data_path, 'output', 'classifier_results')\n",
        "os.makedirs(res_dir, exist_ok = True)\n",
        "fits_path = os.path.join(data_path, 'fits') "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1uQ5JgSHg_VQ"
      },
      "source": [
        "Install required packages:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nKftOu9fyC8R"
      },
      "outputs": [],
      "source": [
        "# %%capture\n",
        "# !pip3 install --upgrade pip\n",
        "# !pip3 install sentencepiece==0.1.96\n",
        "# !pip3 install datasets==2.2.2\n",
        "# !pip3 install tokenizers==0.12.1 # on Mac you need to have Rust installed\n",
        "# !pip3 install transformers==4.19.2\n",
        "# !pip3 install scikit-learn==1.2.2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8X7ISk4khCWR"
      },
      "source": [
        "Load modules:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y5f1fFbETSbM"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import shutil\n",
        "import gc\n",
        "import random\n",
        "from tqdm.notebook import tqdm\n",
        "import re\n",
        "import json\n",
        "from typing import Union\n",
        "\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "from transformers import (\n",
        "    AutoTokenizer,\n",
        "    AutoModelForSequenceClassification,\n",
        "    # for training\n",
        "    Trainer,\n",
        "    TrainingArguments,\n",
        "    EarlyStoppingCallback,\n",
        "    # for reproducibility (see https://github.com/huggingface/transformers/pull/16907)\n",
        "    set_seed,\n",
        "    enable_full_determinism\n",
        ")\n",
        "\n",
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "from sklearn.metrics import classification_report, precision_recall_fscore_support"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S3clKYWdhY9Q"
      },
      "source": [
        "Set seeds for reproducibility:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I26TKRr1hLTG"
      },
      "outputs": [],
      "source": [
        "SEED = 1234\n",
        "set_seed(SEED)\n",
        "enable_full_determinism(SEED)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XWqZ7LGMFeHV"
      },
      "source": [
        "# Data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eQWLAKg4aSeG"
      },
      "source": [
        "## Description"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BqBm2yyeYta4"
      },
      "source": [
        "The dataset we'll load has the following columns:\n",
        "\n",
        "- `item_id` (str): unique ID of the tweet (constructed by concatenating the three-letter ISO country code of the party posting the tweet, `user_id`, and `status_id`)\n",
        "- `user_id` (int): the ID of the account that posted the tweet\n",
        "- `status_id` (int): the ID of the tweet\n",
        "- `labeling` (str): the label class the tweet has been assigned to (i.e., its label)\n",
        "- `text` (str): the tweet's text (in its original language)\n",
        "- `test_` (bool): Boolean flag indicating tweets that belong in the test (not the training) data split\n",
        "\n",
        "Note that user and status IDs are very long integers.\n",
        "Hence, I read them as pandas' nullable `Int64` type to ensure they are not corrupted (e.g., by being converted to floating-point numbers, which cannot represent such large integers exactly).\n",
        "(Alternatively, you could just treat them as strings.)\n",
        "\n",
        "To this end, I create a dictionary mapping column names to the desired data types that I'll pass when reading the CSV file:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Wrj0Mxe2am0W"
      },
      "outputs": [],
      "source": [
        "col_types = {\n",
        "  'item_id': str,\n",
        "  'user_id': 'Int64',\n",
        "  'status_id': 'Int64',\n",
        "  'labeling': str,\n",
        "  'text': str,\n",
        "  'test_': bool\n",
        "}"
      ]
    },
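    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To illustrate why the ID columns should not end up as floating-point numbers (which is what pandas does by default with integer columns that contain missing values), here is a minimal sketch with a hypothetical tweet ID.\n",
        "A 64-bit float has a 53-bit significand, so it represents every integer exactly only up to $2^{53}$ (about 9.007 quadrillion), while tweet status IDs can be far larger:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# integers above 2**53 cannot all be represented exactly by a 64-bit float\n",
        "status_id = 1234567890123456789  # hypothetical 19-digit tweet ID\n",
        "as_float = float(status_id)\n",
        "\n",
        "# the round trip through float changes the value, i.e., the ID gets corrupted\n",
        "print(int(as_float) == status_id)  # False\n",
        "# even 2**53 + 1 already collapses onto 2**53\n",
        "print(float(2**53 + 1) == float(2**53))  # True"
      ]
    },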
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rP65yYzRYLsb"
      },
      "source": [
        "## Load data\n",
        "\n",
        "Read the labeled tweets dataset from disk:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DsmH1rS_XKgr"
      },
      "outputs": [],
      "source": [
        "fp = os.path.join(input_path, 'training_data_pooled_samples.csv')\n",
        "dat = pd.read_csv(fp, sep = ',', dtype = col_types)\n",
        "dat.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "btNyCm4RYfiq"
      },
      "source": [
        "Set unique IDs as index."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MW-9z7OLYPpL"
      },
      "outputs": [],
      "source": [
        "dat.set_index('item_id', inplace = True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n9kI6uU8cXP6"
      },
      "source": [
        "## Create binary labels\n",
        "\n",
        "Let's have a look at the `labeling` values:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "s6gLkIe8cbi0",
        "outputId": "78784c63-08cf-4ca8-f017-c3c6edcd2c8a"
      },
      "outputs": [],
      "source": [
        "dat.labeling.value_counts()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dw4127o6chWY"
      },
      "source": [
        "The `labeling` values indicate whether a tweet contains\n",
        "\n",
        "1. **no** elite criticism,\n",
        "2. elite criticism directed at **the elite in general**, or\n",
        "3. criticism of **specific elites**.\n",
        "\n",
        "We argue that *the essence of anti-elite rhetoric* (as a political strategy) is generalized elite criticism.\n",
        "Hence, we are mainly interested in the distinction between 'general' elite criticism and all other statements.\n",
        "Accordingly, I create a **binary** label indicator that equals 1 for tweets labeled `yes-general` and 0 otherwise."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5Et8jtZGdPM7",
        "outputId": "6c18f403-d891-455b-b2c1-74039f20da11"
      },
      "outputs": [],
      "source": [
        "dat['label_'] = dat.labeling == 'yes-general' # positive (negative) class label => True (False)\n",
        "dat['label_'] = dat['label_'].astype(int)  # positive (negative) class label => 1 (0)\n",
        "dat.label_.value_counts()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gK7wm27EkKgH"
      },
      "source": [
        "## Preprocess texts\n",
        "\n",
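        "Since URLs and account handles (`@...`) were masked when pre-training the XLM-T model, we replicate this preprocessing step, replacing them with the placeholders `http` and `@user`.\n",
        "I use the function provided by the model creators (from [here](https://huggingface.co/cardiffnlp/twitter-xlm-roberta-base)).\n",
        "However, I make two changes:\n",
        "\n",
        "1. I convert all whitespace characters to single spaces (to remove tabs and line breaks inside tweets).\n",
        "2. I improve the replacement of account handles: a regular expression replaces handles anywhere inside a whitespace-delimited token (e.g., in `.@user:`), not only in tokens that start with `@`.\n"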
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XUvzHYssk2n-"
      },
      "outputs": [],
      "source": [
        "def preprocess_tweet_text(text, handle_regex = r'@[A-Za-z0-9_]{4,15}'):\n",
        "  # convert spaces\n",
        "  text = re.sub(r'\\s+', u'\\x20', text)\n",
        "  new_text = []\n",
        "  for t in text.split(u'\\x20'):\n",
        "    t = re.sub(handle_regex, '@user', t) if re.search(handle_regex, t) and len(t) > 1 else t\n",
        "    t = 'http' if t.startswith('http') else t\n",
        "    new_text.append(t)\n",
        "  return u'\\x20'.join(new_text)"
      ]
    },
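    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the difference to the model creators' version concrete, the sketch below compares `model_card_style` (my reconstruction of the token-level logic shown on the model card, not an official API) with `regex_style`, which uses the same handle regex as `preprocess_tweet_text`.\n",
        "For simplicity, both split on any whitespace with `str.split()`, and the example tweet is hypothetical.\n",
        "Note that the `{4,15}` quantifier leaves handles shorter than four characters unchanged."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "HANDLE_REGEX = r'@[A-Za-z0-9_]{4,15}'  # same pattern as in preprocess_tweet_text\n",
        "\n",
        "def model_card_style(text):\n",
        "  # replaces a token only if it *starts* with '@'\n",
        "  tokens = ['@user' if t.startswith('@') and len(t) > 1 else t for t in text.split()]\n",
        "  tokens = ['http' if t.startswith('http') else t for t in tokens]\n",
        "  return ' '.join(tokens)\n",
        "\n",
        "def regex_style(text):\n",
        "  # replaces handles anywhere inside a token, keeping surrounding punctuation\n",
        "  tokens = [re.sub(HANDLE_REGEX, '@user', t) for t in text.split()]\n",
        "  tokens = ['http' if t.startswith('http') else t for t in tokens]\n",
        "  return ' '.join(tokens)\n",
        "\n",
        "example = 'Thanks .@GreenParty for the support! https://t.co/abc'  # hypothetical tweet\n",
        "print(model_card_style(example))  # handle survives because the token starts with '.'\n",
        "print(regex_style(example))       # handle is replaced by '@user'"
      ]
    },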
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WuIAgBoThb_F",
        "outputId": "cefe41c5-cc20-4bc5-d4f1-6e76b63e380e"
      },
      "outputs": [],
      "source": [
        "# test\n",
        "for tweet in dat.text.values[:5]:\n",
        "  print(tweet)\n",
        "  print('=>', preprocess_tweet_text(tweet))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UrxL1YIbmiKz"
      },
      "source": [
        "Preprocess all texts:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1G81jyac_SVC"
      },
      "outputs": [],
      "source": [
        "dat['text'] = dat.text.apply(preprocess_tweet_text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bAMekJn7kFVU"
      },
      "source": [
        "## Split into training, validation, and test data sets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t4zQ6tGSb7iN"
      },
      "source": [
        "Now we can split the dataset into the training and test partitions.\n",
        "To do so, we use the `test_` indicator column that comes with the dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7LGMOnxAbwpP",
        "outputId": "c9d6ce5c-504e-44fe-986d-e90238eeea76"
      },
      "outputs": [],
      "source": [
        "train_dat = dat[~dat.test_]\n",
        "print(f'No. train samples: {len(train_dat)}; pos. label proportion: {train_dat.label_.values.mean():.3f}')\n",
        "test_dat = dat[dat.test_]\n",
        "print(f'No. test samples:  {len(test_dat)}; pos. label proportion: {test_dat.label_.values.mean():.3f}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rbdmvezh_-la"
      },
      "source": [
        "## Convert test data to `Dataset` instance"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q9jBvb5yAR0D"
      },
      "source": [
        "### Subset to text--label pairs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qORfLANSAQ7O"
      },
      "outputs": [],
      "source": [
        "test_dataset = dict(text = test_dat.text.values.tolist(), labels = test_dat.label_.values.tolist())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l0dZXcR4m5rB"
      },
      "source": [
        "### Encode texts\n",
        "\n",
        "The XLM-T model comes with a pre-trained (subword) tokenizer.\n",
        "This tokenizer converts tweets' texts into sequences of integers that index the model's vocabulary.\n",
        "This conversion of tweet texts into numeric representations (called 'encoding') is necessary for fine-tuning.\n",
        "\n",
        "So we download and initialize the pre-trained tokenizer ..."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1IjMOsNSyC8d"
      },
      "outputs": [],
      "source": [
        "tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RX-l5ZiKD-xN"
      },
      "source": [
        "... and apply it to encode the texts in the test data split:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WGfjnUAkACgq",
        "outputId": "8a54648e-4e12-4b9a-ad7c-b05aee087c1e"
      },
      "outputs": [],
      "source": [
        "test_encodings = tokenizer(test_dataset['text'], truncation=True, padding=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "puaXrohABgtj"
      },
      "source": [
        "For example, here are the text of the first tweet in the test data split and its (unpadded) token IDs and tokens:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gVxil2W_A5gn",
        "outputId": "b2109fb1-76df-4003-efb1-682edb63e836"
      },
      "outputs": [],
      "source": [
        "print(test_dataset['text'][0])\n",
        "for tok_id, tok in zip(test_encodings[0].ids, test_encodings[0].tokens):\n",
        "  if tok_id == tokenizer.pad_token_id:\n",
        "    break\n",
        "  print(tok_id, '\\t', tok)"
      ]
    },
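    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loop above stops at the first padding token because, with `padding=True`, the tokenizer pads every encoded text to the length of the longest text passed in the same call (here, the whole test split) and records an `attention_mask` that is 1 for real tokens and 0 for padding.\n",
        "The sketch below mimics this right-padding on hypothetical token IDs; `pad_to_longest` is a hypothetical helper, and the pad ID of 1 is an assumption (XLM-R's, but check `tokenizer.pad_token_id`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "PAD_ID = 1  # assumption: XLM-R's padding token ID; check tokenizer.pad_token_id\n",
        "\n",
        "def pad_to_longest(sequences, pad_id=PAD_ID):\n",
        "  # pad every sequence (on the right) to the length of the longest one\n",
        "  max_len = max(len(seq) for seq in sequences)\n",
        "  input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]\n",
        "  # 1 marks real tokens, 0 marks padding\n",
        "  attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]\n",
        "  return {'input_ids': input_ids, 'attention_mask': attention_mask}\n",
        "\n",
        "# hypothetical encodings (assuming <s> = 0 and </s> = 2; other IDs made up)\n",
        "batch = pad_to_longest([[0, 581, 2], [0, 1204, 7, 31, 2]])\n",
        "for ids, mask in zip(batch['input_ids'], batch['attention_mask']):\n",
        "  print(ids, mask)"
      ]
    },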
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_s2U4gbjCBvo"
      },
      "source": [
        "### Convert to custom dataset instance\n",
        "\n",
        "Hugging Face's `Trainer` operates on torch `Dataset`s that return one example (a dictionary of tensors) at a time.\n",
        "\n",
        "Hence, we define a custom dataset class:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1MzkHFG5yC8f"
      },
      "outputs": [],
      "source": [
        "class TweetsDataset(torch.utils.data.Dataset):\n",
        "  '''\n",
        "  Custom Tweets datasets class\n",
        "\n",
        "  Parameters:\n",
        "  -----------\n",
        "  encodings : 'transformers.tokenization_utils_base.BatchEncoding' instance of\n",
        "      tokenized and encoded texts\n",
        "  labels : list or 1d-array recording integer label indicators\n",
        "  '''\n",
        "  def __init__(self, encodings, labels):\n",
        "    self.encodings = encodings\n",
        "    self.labels = labels\n",
        "\n",
        "  def __getitem__(self, idx):\n",
        "    item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
        "    item['labels'] = torch.tensor(self.labels[idx])\n",
        "    return item\n",
        "\n",
        "  def __len__(self):\n",
        "    return len(self.labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PszgMb0UMUqh"
      },
      "outputs": [],
      "source": [
        "test_dataset = TweetsDataset(encodings=test_encodings, labels=test_dataset['labels'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z_BTQBaJyC8g"
      },
      "source": [
        "# Fine-tuning"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ccU5vewTKpil"
      },
      "source": [
        "## Fixed hyper-parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "REu85dxgJXMV"
      },
      "outputs": [],
      "source": [
        "NUM_LABELS = 2\n",
        "LR = 2e-5\n",
        "EPOCHS = 8\n",
        "BATCH_SIZE = 32"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DNmbGSqRJ5Mc"
      },
      "source": [
        "Set up the Trainer arguments (can be re-used across training runs):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PGuho0dMyC8g"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "training_args = TrainingArguments(\n",
        "  # for (temporary) storage\n",
        "  output_dir='./temp',                            # output directory\n",
        "  num_train_epochs=EPOCHS,                       # total number of training epochs\n",
        "  per_device_train_batch_size=BATCH_SIZE,        # batch size per device during training\n",
        "  per_device_eval_batch_size=BATCH_SIZE*2,       # batch size for evaluation\n",
        "  warmup_steps=100,                              # number of warmup steps for learning rate scheduler\n",
        "  weight_decay=0.01,                             # strength of weight decay\n",
        "  fp16=torch.cuda.is_available(),\n",
        "  use_mps_device=False,\n",
        "  # model evaluation\n",
        "  evaluation_strategy='epoch',                   # how (and when) to evaluate\n",
        "  metric_for_best_model='f1_pos',                # hard code eval metric\n",
        "  greater_is_better=True,\n",
        "  # reproducibility\n",
        "  full_determinism=True,\n",
        "  seed=SEED,\n",
        "  data_seed=SEED,\n",
        "  # from https://discuss.huggingface.co/t/save-only-best-model-in-trainer/8442/8\n",
        "  save_strategy='epoch',                            # if and when to save (must match `evaluation_strategy`)\n",
        "  save_total_limit=2,\n",
        "  load_best_model_at_end=True                   # load best model at the end?\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TKjsRTTJEkZm"
      },
      "source": [
        "## Define helper functions and classes"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z7VDiGGU0Gnj"
      },
      "source": [
        "To fine-tune an XLM-T model on our data, we have to prepare the data in a number of ways:\n",
        "\n",
        "1. convert our sets of label--text tuples ('samples') into iterable datasets\n",
        "2. 'encode' tweets' texts with XLM-T's built-in pre-trained tokenizer\n",
        "3. collect train and validation samples in separate `TweetsDataset` instances\n",
        "\n",
        "And if we want to apply $k$-fold cross-validation to select optimal hyperparameters, we need to repeat these steps for each of the $k$ folds."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iJoToQRqEmDq"
      },
      "source": [
        "### Function for creating a dataset\n",
        "\n",
        "The helper function below uses the pre-trained tokenizer and our custom `TweetsDataset` class from above to create a dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4zvAhR8sEXGF"
      },
      "outputs": [],
      "source": [
        "def create_dataset(texts, labels):\n",
        "  '''Creates a TweetsDataset instance'''\n",
        "  return TweetsDataset(tokenizer(texts, truncation=True, padding=True), labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M1objO8-FItE"
      },
      "source": [
        "### Function for initializing a model instance"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wLccW47KFfRA",
        "outputId": "0cce219f-6ad5-4c03-c752-d9cbf3b9a55d"
      },
      "outputs": [],
      "source": [
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "device = torch.device(device)\n",
        "device"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "swibVJ-4XjE8"
      },
      "outputs": [],
      "source": [
        "def model_init():\n",
        "  return AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=NUM_LABELS)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kiuyvmfb5n4p"
      },
      "source": [
        "### Custom Trainer class with weighted loss\n",
        "\n",
        "\n",
        "Note that the positive label class (general elite criticism) is outnumbered 4:1 in the training, validation, and test data splits.\n",
        "Such class imbalance can lead to very poor performance in the positive label class, the class we are mainly interested in (see [here](https://discuss.huggingface.co/t/precision-vs-recall-when-using-transformer-models/13448)).\n",
        "\n",
        "I therefore subclass the Trainer to use a weighted loss (see [here](https://discuss.huggingface.co/t/how-can-i-use-class-weights-when-training/1067/7), [here](https://huggingface.co/docs/transformers/main/main_classes/trainer), and [here](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bCFBsXACsYsL"
      },
      "outputs": [],
      "source": [
        "class ClassWeightsTrainer(Trainer):\n",
        "  '''custom Trainer class'''\n",
        "  def __init__(self, device: Union[torch.device,str], class_weights: Union[list, dict], **kwargs):\n",
        "    super().__init__(**kwargs)\n",
        "    self.device = device if isinstance(device, torch.device) else torch.device(device)\n",
        "    if len(class_weights) != self.model.config.num_labels:\n",
        "      raise ValueError(f'length of `class_weights` must be {self.model.config.num_labels}')\n",
        "    if isinstance(class_weights, dict):\n",
        "      if set(class_weights.keys()) != set(self.model.config.id2label.keys()):\n",
        "        raise ValueError(f'keys of `class_weights` mismatch label classes {list(self.model.config.id2label.keys())}')\n",
        "      # order weights by label ID (not by weight value)\n",
        "      class_weights = [v for k, v in sorted(class_weights.items(), key=lambda item: item[0])]\n",
        "    self.class_weights = torch.tensor(class_weights, dtype=self.model.dtype)\n",
        "    if str(self.device) != 'cpu':\n",
        "      self.model.to(self.device);\n",
        "      self.class_weights = self.class_weights.to(self.device)\n",
        "\n",
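        "  # **kwargs absorbs extra arguments (e.g., `num_items_in_batch`) passed by newer transformers versions\n",
        "  def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",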
        "    labels = inputs.get('labels')\n",
        "    # forward pass\n",
        "    outputs = model(**inputs)\n",
        "    logits = outputs.get('logits')\n",
        "    if str(self.device) == 'cpu':\n",
        "      logits, labels = logits.cpu(), labels.cpu()\n",
        "    # compute custom loss\n",
        "    loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)\n",
        "    loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))\n",
        "    return (loss, outputs) if return_outputs else loss\n"
      ]
    },
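    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the class weights do, here is a minimal pure-Python sketch of the weighted cross-entropy that `nn.CrossEntropyLoss(weight=...)` computes with its default `reduction='mean'`: each example's negative log-likelihood is multiplied by the weight of its true class, and the sum is divided by the sum of those weights.\n",
        "One consequence is that multiplying all weights by a constant does not change the loss, so only their ratio matters, and a 50:50 scheme is equivalent to an unweighted loss.\n",
        "The helper names, logits, and labels below are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def log_softmax(row):\n",
        "  # numerically stable log-softmax via the log-sum-exp trick\n",
        "  m = max(row)\n",
        "  log_z = m + math.log(sum(math.exp(x - m) for x in row))\n",
        "  return [x - log_z for x in row]\n",
        "\n",
        "def weighted_cross_entropy(logits, labels, class_weights):\n",
        "  # negative log-likelihood of the true class, weighted by that class's weight\n",
        "  num = sum(class_weights[y] * -log_softmax(row)[y] for row, y in zip(logits, labels))\n",
        "  # normalize by the summed weights of the true classes (mean reduction)\n",
        "  den = sum(class_weights[y] for y in labels)\n",
        "  return num / den\n",
        "\n",
        "logits = [[2.0, 0.0], [0.5, 0.3], [0.1, 1.5]]  # hypothetical model outputs\n",
        "labels = [0, 1, 1]                               # 0 = negative, 1 = positive\n",
        "\n",
        "print(weighted_cross_entropy(logits, labels, [0.5, 0.5]))    # same as unweighted\n",
        "print(weighted_cross_entropy(logits, labels, [1.0, 1.0]))\n",
        "print(weighted_cross_entropy(logits, labels, [0.25, 0.75]))  # up-weights positive examples"
      ]
    },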
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6mpV7_aYFNR2"
      },
      "source": [
        "### Function for computing evaluation metrics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MGb54RE0boCv"
      },
      "outputs": [],
      "source": [
        "metrics = [m + '_' + suffix for suffix in ['macro', 'micro', 'pos', 'neg'] for m in ['f1', 'prec', 'recl']]\n",
        "\n",
        "def compute_metrics(eval_pred, metrics=metrics):\n",
        "  '''compute evaluation metrics'''\n",
        "  preds, labs = eval_pred\n",
        "  preds = np.argmax(preds, axis=1)\n",
        "\n",
        "  res = dict()\n",
        "  res['prec_macro'], res['recl_macro'], res['f1_macro'], _ = precision_recall_fscore_support(labs, preds, average = 'macro')\n",
        "  res['prec_micro'], res['recl_micro'], res['f1_micro'], _ = precision_recall_fscore_support(labs, preds, average = 'micro')\n",
        "  (res['prec_neg'], res['prec_pos']), (res['recl_neg'], res['recl_pos']), (res['f1_neg'], res['f1_pos']), _ = precision_recall_fscore_support(labs, preds, labels = [0, 1], average = None)\n",
        "\n",
        "  return {m: res[m] for m in metrics}"
      ]
    },
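    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the sketch below computes per-class precision, recall, and F1 by hand for the binary case, which is what `precision_recall_fscore_support(..., average=None)` returns per class.\n",
        "As in scikit-learn, macro F1 is the unweighted mean of the per-class F1 scores (not the F1 of macro precision and recall), and micro F1 equals accuracy in single-label classification.\n",
        "The helper `per_class_prf` and the labels and predictions are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def per_class_prf(labs, preds, cls):\n",
        "  # counts when treating `cls` as the positive class\n",
        "  tp = sum(1 for y, p in zip(labs, preds) if y == cls and p == cls)\n",
        "  fp = sum(1 for y, p in zip(labs, preds) if y != cls and p == cls)\n",
        "  fn = sum(1 for y, p in zip(labs, preds) if y == cls and p != cls)\n",
        "  # return 0.0 where a ratio is undefined\n",
        "  prec = tp / (tp + fp) if tp + fp else 0.0\n",
        "  recl = tp / (tp + fn) if tp + fn else 0.0\n",
        "  f1 = 2 * prec * recl / (prec + recl) if prec + recl else 0.0\n",
        "  return prec, recl, f1\n",
        "\n",
        "labs  = [1, 0, 1, 1, 0]  # hypothetical gold labels\n",
        "preds = [1, 0, 0, 1, 1]  # hypothetical predictions\n",
        "\n",
        "prec_neg, recl_neg, f1_neg = per_class_prf(labs, preds, 0)\n",
        "prec_pos, recl_pos, f1_pos = per_class_prf(labs, preds, 1)\n",
        "f1_macro = (f1_neg + f1_pos) / 2\n",
        "f1_micro = sum(y == p for y, p in zip(labs, preds)) / len(labs)  # equals accuracy here\n",
        "print(f'f1_pos={f1_pos:.3f}, f1_neg={f1_neg:.3f}, f1_macro={f1_macro:.3f}, f1_micro={f1_micro:.3f}')"
      ]
    },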
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "waCETbJwJmGn"
      },
      "source": [
        "### Function for extracting the evaluation log from a Trainer instance"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lDxVjCdsf0Zh"
      },
      "outputs": [],
      "source": [
        "def get_train_history(trainer):\n",
        "  log = pd.DataFrame([step for step in trainer.state.log_history if 'eval_f1_macro' in step.keys()])\n",
        "  return log"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qL4RpA_pFSxm"
      },
      "source": [
        "## Select 'optimal' class weights through 5-fold cross-validation\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NIPN12V4KzGM"
      },
      "source": [
        "Now, we define the class weighting schemes we want to evaluate.\n",
        "I include the default 50:50 class weights and add schemes that up-weight the positive (minority) class 3:1, 7:1, and 15:1 (positive class weights of 0.75, 0.875, and 0.9375), respectively."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "aLsokSwH2cHL",
        "outputId": "bfe2eb3b-f28d-4a77-a67e-5e09a180be0a"
      },
      "outputs": [],
      "source": [
        "# define grid of class weight values\n",
        "nws = 0.5**np.asarray(range(1,5)) # [0.5, 0.25, 0.125, 0.0625]\n",
        "pws = 1-nws\n",
        "class_weights = list(zip(nws.tolist(), pws.tolist()))\n",
        "class_weights"
      ]
    },
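    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because `nn.CrossEntropyLoss` with its default mean reduction normalizes by the summed weights, only the ratio of the two class weights matters.\n",
        "A minimal plain-Python sketch that re-creates the grid above and computes the effective up-weighting factor of the positive class for each scheme:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# re-create the grid: negative weights 1/2, 1/4, 1/8, 1/16 and positive weights 1 - w\n",
        "nws = [0.5 ** k for k in range(1, 5)]\n",
        "pws = [1 - w for w in nws]\n",
        "# ratio of positive to negative class weight\n",
        "ratios = [pw / nw for nw, pw in zip(nws, pws)]\n",
        "print(ratios)  # [1.0, 3.0, 7.0, 15.0]"
      ]
    },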
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cDKW2WE8RpSu"
      },
      "outputs": [],
      "source": [
        "# file to load the train--val indices from\n",
        "fp = os.path.join(input_path, 'cv_ids.json')\n",
        "with open(fp, 'r') as file:\n",
        "  cv_ids = json.load(file)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mbeua6Byaw8V",
        "outputId": "664a65d2-db30-4bad-a8c4-27828763bad6"
      },
      "outputs": [],
      "source": [
        "for iter, ids in cv_ids.items():\n",
        "  print(iter, len(ids['train']), len(ids['val']))"
      ]
    },
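    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before launching the (long) cross-validation loop, it can be worth checking that the train and validation IDs of each split are disjoint and that every ID occurs in the training data.\n",
        "The sketch below assumes that `cv_ids` maps iteration keys to dictionaries with `'train'` and `'val'` lists of item IDs, as the loop above suggests; the function `check_cv_ids` and the toy IDs are hypothetical.\n",
        "In the notebook, you could call it as `check_cv_ids(cv_ids, train_dat.index)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def check_cv_ids(cv_ids, known_ids):\n",
        "  '''return a list of problems found in the CV splits (empty if none)'''\n",
        "  known = set(known_ids)\n",
        "  problems = []\n",
        "  for it, ids in cv_ids.items():\n",
        "    train, val = set(ids['train']), set(ids['val'])\n",
        "    # train and validation IDs must not overlap\n",
        "    if train & val:\n",
        "      problems.append(f'split {it}: {len(train & val)} IDs in both train and val')\n",
        "    # every ID must exist in the training data\n",
        "    unknown = (train | val) - known\n",
        "    if unknown:\n",
        "      problems.append(f'split {it}: {len(unknown)} IDs not in the training data')\n",
        "  return problems\n",
        "\n",
        "# usage with toy IDs (split '1' deliberately reuses 'a' in train and val)\n",
        "toy_cv_ids = {'0': {'train': ['a', 'b'], 'val': ['c']}, '1': {'train': ['a', 'c'], 'val': ['a']}}\n",
        "print(check_cv_ids(toy_cv_ids, ['a', 'b', 'c']))"
      ]
    },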
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B5k9UrmlUt5n"
      },
      "outputs": [],
      "source": [
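        "def remove_artifacts(path):\n",
        "  '''delete checkpoint directories and the logging directory ('runs') written by the Trainer'''\n",
        "  dirs = [nm for nm in os.listdir(path) if nm.startswith('checkpoint-')]\n",
        "  dirs.append('runs')\n",
        "  for nm in dirs:\n",
        "    # skip directories that do not exist (e.g., if no 'runs' directory was written)\n",
        "    if os.path.isdir(os.path.join(path, nm)):\n",
        "      shutil.rmtree(os.path.join(path, nm))"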
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "logs_fp = os.path.join(res_dir, 'xlmt_finetuning_eval_log.tab')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "744642cd21b744758b33a215e953c51a",
            "ca576e6d10324d1fbdbdcf9bc83f2ca2",
            "9a02476e243542969255ef18700ce5e6",
            "b561f4f9971242d19b2bada693893050",
            "c18c288050b84116a09d65b7c9c066c9",
            "36b010cb7d99405cbbf0fdd2f7158dd0",
            "de5e6b894a5e4ed38b0d457f9a82f5f0",
            "c9deceda0f8741408a302b55406ff6b4",
            "d70f917895d8445db8d9def2964179ce",
            "a8cd11d04db34abdbb0b2a8afc32f628",
            "63c269ebd8c244eba490de2464b19688"
          ]
        },
        "id": "dXhjbnjp-9Dy",
        "outputId": "03e754ee-1a00-432c-c433-9037a7d9fdaf"
      },
      "outputs": [],
      "source": [
        "# iterate over CV train-val splits\n",
        "# NOTE: this is going to run for a while\n",
        "for cv_iter, ids in cv_ids.items():\n",
        "  cv_iter = int(cv_iter)  # cast split key to int (keys may be stored as strings)\n",
        "  train_idxs = ids['train']\n",
        "  val_idxs = ids['val']\n",
        "\n",
        "  # create training and validation datasets\n",
        "  train_dataset = create_dataset(\n",
        "    train_dat.loc[train_idxs].text.values.tolist(),\n",
        "    train_dat.loc[train_idxs].label_.values.tolist()\n",
        "  )\n",
        "  val_dataset = create_dataset(\n",
        "    train_dat.loc[val_idxs].text.values.tolist(),\n",
        "    train_dat.loc[val_idxs].label_.values.tolist()\n",
        "  )\n",
        "\n",
        "  # iterate over (negative, positive) class weight tuples\n",
        "  logs = list()\n",
        "  for nw, pw in class_weights:\n",
        "    print(f'\\n__________\\nCV iteration nr. {cv_iter}: training with positive class weight {pw}\\n\\n')\n",
        "    # instantiate\n",
        "    trainer = ClassWeightsTrainer(\n",
        "      model_init = model_init,\n",
        "      device = device,\n",
        "      args = training_args,\n",
        "      train_dataset = train_dataset,\n",
        "      eval_dataset = val_dataset,\n",
        "      compute_metrics = compute_metrics,\n",
        "      class_weights = {0: nw, 1: pw},\n",
        "      callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]\n",
        "    )\n",
        "    # train\n",
        "    trainer.train()\n",
        "    # get eval history\n",
        "    log = get_train_history(trainer)\n",
        "    # add CV iteration indicator\n",
        "    log['cv_iter'] = cv_iter\n",
        "    # add current class weights\n",
        "    log['neg_class_weight'] = nw\n",
        "    log['pos_class_weight'] = pw\n",
        "    logs.append(log)\n",
        "    # clean up\n",
        "    remove_artifacts(path=training_args.output_dir)\n",
        "\n",
        "  # combine row-wise\n",
        "  logs = pd.concat(logs, axis = 0)\n",
        "\n",
        "  # write (overwrite, then append) to TSV\n",
        "  logs.to_csv(\n",
        "    logs_fp,                              # file path\n",
        "    sep = '\\t',                           # tab-separated (i.e., TSV)\n",
        "    index = False,                        # omit index\n",
        "    mode = 'w' if cv_iter == 0 else 'a',  # overwrite in first iteration, append afterwards\n",
        "    header = cv_iter == 0                 # write header only the first time\n",
        "  )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NjUMdfAVTT1U"
      },
      "source": [
        "Now we reload the full log from disk to determine which class weighting scheme yields the classifiers with the highest F1 score for the positive class (accounting for variability across CV train/validation splits):"
      ]
    },
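    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The selection criterion applied in the next two cells can be written out as a worked formula (a sketch of what the code computes; the notation is introduced here for illustration only).\n",
        "For positive-class weight $w$, CV split $k = 1, \\dots, K$, and evaluation step $e$ (e.g., an epoch), let $F_{k,w,e}$ denote the validation F1 score for the positive class (`eval_f1_pos`).\n",
        "For each split and weight, the code keeps the evaluation step with the highest positive-class F1, $e^{*}_{k,w} = \\arg\\max_{e} F_{k,w,e}$, and reads every logged metric $M$ (not only F1) at that step.\n",
        "It then summarizes each metric across splits as mean ± standard deviation:\n",
        "\n",
        "$$\n",
        "\\bar{M}_{w} = \\frac{1}{K} \\sum_{k=1}^{K} M_{k,w,e^{*}_{k,w}}, \\qquad s_{w} = \\sqrt{\\frac{1}{K-1} \\sum_{k=1}^{K} \\left( M_{k,w,e^{*}_{k,w}} - \\bar{M}_{w} \\right)^{2}}\n",
        "$$\n",
        "\n",
        "Here $s_w$ is the sample standard deviation, matching the default of `pandas.Series.std` (`ddof=1`).\n",
        "If several evaluation steps tie for the maximum F1 within a split, the equality filter keeps all of them, so that split contributes more than one row to the average."
      ]
    },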
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 443
        },
        "id": "cekKo0z2WDv9",
        "outputId": "36eb5689-408f-4a1f-ae82-fc3215346645"
      },
      "outputs": [],
      "source": [
        "# read full logs from disk\n",
        "logs = pd.read_csv(logs_fp, sep = '\\t')\n",
        "\n",
        "eval_metrics = ['eval_' + m for m in metrics]\n",
        "\n",
        "logs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 226
        },
        "id": "bq1C-kprW3XD",
        "outputId": "6056ae22-4f94-4949-ad9d-677f9ddc6176"
      },
      "outputs": [],
      "source": [
        "# function for summarizing eval metrics as 'mean ± SD'\n",
        "def format_sum_stat(x):\n",
        "  m = x.mean()\n",
        "  sd = x.std()\n",
        "  return f'{m:.3f} ± {sd:.3f}'\n",
        "\n",
        "# get max positive-class F1 values within CV iteration and class weight groups\n",
        "logs['max_eval_f1_pos'] = (\n",
        "  logs\n",
        "  .groupby(['cv_iter', 'pos_class_weight'])['eval_f1_pos']\n",
        "  .transform('max')\n",
        ")\n",
        "\n",
        "# subset to rows with max positive-class F1 values within CV iteration and class weight groups\n",
        "best_perf = logs[logs.eval_f1_pos == logs.max_eval_f1_pos][['pos_class_weight'] + eval_metrics]\n",
        "\n",
        "# summarize across CV iterations (one row per positive class weight)\n",
        "avg_perfs = best_perf.groupby('pos_class_weight').agg(format_sum_stat)\n",
        "\n",
        "avg_perfs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t-wEvSaL-ruS"
      },
      "source": [
        "When averaging across CV iterations and taking into account variability in the performance estimates, the results show that, as expected, changing the class weights has some impact on the classifiers' ability to detect the positive class:\n",
        "positive-class recall is highest when weighting positive samples 4:1, but this comes at the cost of reduced precision.\n",
        "Hence, I use the 8:1 weighting scheme to train the final classifier."
      ]
    },
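    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Why class weights trade precision for recall can be seen from the loss function. This is a sketch under an assumption: `ClassWeightsTrainer` (defined or imported earlier in this notebook) is assumed here to pass `class_weights` to a standard weighted cross-entropy loss such as `torch.nn.CrossEntropyLoss(weight=...)`.\n",
        "For a batch of $N$ tweets with labels $y_i \\in \\{0, 1\\}$, predicted class probabilities $p_i(\\cdot)$, and class weights $w_0$ (negative) and $w_1$ (positive), this loss with PyTorch's default `reduction='mean'` is\n",
        "\n",
        "$$\n",
        "\\mathcal{L} = -\\frac{\\sum_{i=1}^{N} w_{y_i} \\log p_i(y_i)}{\\sum_{i=1}^{N} w_{y_i}}\n",
        "$$\n",
        "\n",
        "Raising $w_1$ relative to $w_0$ makes errors on positive tweets more costly than errors on negative ones, which pushes the model toward predicting the positive class more often; positive-class recall therefore tends to rise while precision tends to fall.\n",
        "For example, with the weights used for the final model below ($w_0 = 0.125$, $w_1 = 0.875$), each positive tweet's loss term is scaled by $0.875 / 0.125 = 7$ relative to a negative tweet's."
      ]
    },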
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B64-fZ0oHXKa"
      },
      "source": [
        "## Train the 'best' model and evaluate it on held-out test data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "OhGUsjMXIhr2",
        "outputId": "1bfc571f-f891-4470-9de6-f4feb75bf562"
      },
      "outputs": [],
      "source": [
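        "# metric used by early stopping (and best-checkpoint selection, if enabled)\n",
        "training_args.metric_for_best_model = 'f1_macro'\n",
        "\n",
        "# create training dataset from the full training data\n",
        "train_dataset = create_dataset(\n",
        "  train_dat.text.values.tolist(),\n",
        "  train_dat.label_.values.tolist()\n",
        ")\n",
        "\n",
        "# instantiate\n",
        "# NOTE: eval_dataset is the held-out test set, so early stopping monitors test-set performance\n",
        "trainer = ClassWeightsTrainer(\n",
        "  model_init = model_init,\n",
        "  device = device,\n",
        "  class_weights = {0: 0.125, 1: 0.875}, # '1:8' scheme: negatives weighted 1/8, positives 7/8\n",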
        "  args = training_args,\n",
        "  train_dataset = train_dataset,\n",
        "  eval_dataset = test_dataset,\n",
        "  compute_metrics = compute_metrics,\n",
        "  callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]\n",
        ")\n",
        "\n",
        "# train\n",
        "trainer.train()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 245
        },
        "id": "zd5nkjQ7g6CW",
        "outputId": "0c632baf-0a66-41ed-d96e-11468c510fa1"
      },
      "outputs": [],
      "source": [
        "# get eval history\n",
        "log = get_train_history(trainer)\n",
        "\n",
        "# evaluate\n",
        "test_preds_raw, test_labels, res = trainer.predict(test_dataset)\n",
        "res = {m.removeprefix('test_'): v for m, v in res.items() if m.removeprefix('test_') in metrics}\n",
        "\n",
        "# save eval results (currently disabled; uncomment to write the JSON file)\n",
        "res_fp = os.path.join(res_dir, 'xlmt_model_evaluated.json')\n",
        "# with open(res_fp, 'w') as f:\n",
        "#  json.dump(res, f)\n",
        "\n",
        "# print classification report\n",
        "rep = classification_report(\n",
        "    test_labels,\n",
        "    test_preds_raw.argmax(axis=1),\n",
        "    labels = [0, 1],\n",
        "    target_names = ['neg', 'pos']\n",
        ")\n",
        "print(rep)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eiCYTpXmlzil"
      },
      "outputs": [],
      "source": [
        "# save classification report\n",
        "res = classification_report(\n",
        "    test_dat.label_.values,\n",
        "    test_preds_raw.argmax(axis=1),\n",
        "    labels = [0, 1],\n",
        "    target_names = ['neg', 'pos'],\n",
        "    output_dict = True\n",
        ")\n",
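        "res_fp = os.path.join(res_dir, 'xlmt_test_result.tab')\n",
        "# keep the row labels ('neg', 'pos', 'accuracy', 'macro avg', 'weighted avg') in a 'class' column\n",
        "pd.DataFrame(res).transpose().to_csv(res_fp, sep = '\\t', index_label = 'class')"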
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6skAM7CKSkWK",
        "outputId": "2df12502-a679-435d-f977-9acaf9bed169"
      },
      "outputs": [],
      "source": [
        "# save best model (and pre-trained XLM-T tokenizer)\n",
        "model_path = os.path.join(fits_path, 'xlmt-elitecriticism-classifier')\n",
        "os.makedirs(model_path, exist_ok=True)\n",
        "trainer.save_model(model_path)\n",
        "tokenizer.save_pretrained(model_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NooBsRMbp9LE"
      },
      "source": [
        "## Clean up"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "W0BrOPCzpoh3",
        "outputId": "30666f39-e44a-4078-8739-e4176c6b0c34"
      },
      "outputs": [],
      "source": [
        "del trainer, train_dat, test_dat, train_dataset, val_dataset, test_dataset\n",
        "gc.collect()\n",
        "torch.cuda.empty_cache()\n",
        "gc.collect()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qf9oE7bNmKdT"
      },
      "outputs": [],
      "source": [
        "shutil.rmtree('./temp')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rLR23BxIyC8g"
      },
      "source": [
        "<a id='ft_native'></a>"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
